

import torch
import torch.nn.utils.prune as prune
import time
import sys
from Tx import *
from channel import *
from utils import *
from simNeuralEQ import *
import device
import argparse


def pruning(model, pruneMethod, pruneRatio):
	"""Globally prune the ``weight`` tensors of all ``torch.nn.Linear`` layers.

	Every ``torch.nn.Linear`` sub-module of ``model`` contributes its ``weight``
	tensor to a single call of ``prune.global_unstructured``, so the pruning
	amount is applied over all collected weights together rather than layer by
	layer. Biases and non-Linear modules are not touched. The model is
	modified in place.

	Args:
		model: ``torch.nn.Module`` whose Linear layers are pruned.
		pruneMethod: Pruning method class forwarded as ``pruning_method``
			(the script in this file passes ``prune.L1Unstructured``).
		pruneRatio: Forwarded unchanged as ``amount`` to
			``prune.global_unstructured``.

	Raises:
		ValueError: If ``model`` contains no ``torch.nn.Linear`` layer. Without
			this check the empty list would fail inside torch with a much less
			helpful error.
	"""
	parametersToPrune = [
		(module, "weight")
		for module in model.modules()
		if isinstance(module, torch.nn.Linear)
	]
	if not parametersToPrune:
		raise ValueError("pruning(): model contains no torch.nn.Linear layers to prune")

	prune.global_unstructured(
		parametersToPrune,
		pruning_method=pruneMethod,
		amount=pruneRatio,
	)



def iterPrune(
	model,			#@@ Pre-trained model to be pruned
	pruneMethod,	#@@ Pruning method
	pruneRatio,		#@@ Pruning ratio

	numIterPrune,	#@@ Number of pruning
	numFineTune,	#@@ Number of fine tuning for each pruning

	snrValid,		#@@ SNR for validation data set
	flagN,			#@@ Noise switch
	chSBR,			#@@ Channel SBR

	mod,
	inSize,
	outSize,
	batchSize,
	delay,
	lossFn,
	opt,
	dataSizeTrain,
	snrTrain,
	simName,

	dataSizeValid=None,	#@@ Validation data size (optional, see docstring)
	):
	"""Alternate global pruning and fine-tuning, saving the model each round.

	A validation sequence is generated once (``Tx`` -> ``Channel``). After an
	initial ``trainEval`` call, each of the ``numIterPrune`` rounds prunes the
	model with ``pruning``, calls ``trainEval`` again, records the sparsity
	statistics returned by ``model.measureNnSeqSparsity`` together with
	``berValid[0]``, and saves the model to
	``./results/<simName>_PRUNE/nEq_<mod>_prune<k>.pt``. The collected
	histories are printed at the end; nothing is returned.

	Args:
		model: Pre-trained model to be pruned; modified in place. It must
			provide ``measureNnSeqSparsity``.
		pruneMethod: Pruning method forwarded to ``pruning``.
		pruneRatio: Pruning ratio forwarded to ``pruning`` on every round.
		numIterPrune: Number of prune / fine-tune rounds.
		numFineTune: Passed as both the 5th and 6th positional argument of
			``trainEval`` in every round (the initial call uses 10 for both).
			The meaning of those two arguments is defined in ``trainEval``,
			which is not part of this file.
		snrValid: ``snr`` of the ``Channel`` producing the validation data.
		flagN: Noise switch, passed to ``Channel.run`` and ``trainEval``.
		chSBR: ``sbr`` of the validation ``Channel``; also passed to
			``trainEval``.
		mod: Modulation name; passed to ``Tx`` and ``trainEval`` and used in
			the checkpoint file name.
		inSize, outSize, batchSize, delay, lossFn, opt, dataSizeTrain,
		snrTrain: Forwarded unchanged to ``trainEval``.
		simName: Simulation name used in the checkpoint directory name.
		dataSizeValid: Length passed (as ``int``) to ``tx.run`` for the
			validation sequence. When ``None`` (the default, matching the
			previous behaviour) it is read from the module-level
			``cfg['prune']['dataSizeValid']``, which only exists when this
			file is run as a script.
	"""
	import os  # local import: only needed to create the checkpoint directory

	if dataSizeValid is None:
		# Backward-compatible fallback to the global config set up by the
		# __main__ section of this file.
		dataSizeValid = cfg['prune']['dataSizeValid']

	# Create the checkpoint directory up front so that a missing directory
	# does not abort the run only after the first fine-tuning round.
	saveDir = './results/%s_PRUNE' % simName
	os.makedirs(saveDir, exist_ok=True)

	tx = Tx(mod=mod)

	#@@ Validation sequence, generated once and reused for every round
	chInValid = tx.run(int(dataSizeValid))
	chValid = Channel(sbr=chSBR, snr=snrValid)
	chOutValid = chValid.run(chIn=chInValid, flagN=flagN)

	def runTrainEval(numIter):
		"""Call trainEval with ``numIter`` as its 5th and 6th argument."""
		return trainEval(
			model,
			tx,
			chInValid,
			chOutValid,
			numIter,
			numIter,
			mod,
			chSBR,
			inSize,
			outSize,
			batchSize,
			delay,
			lossFn,
			opt,
			dataSizeTrain,
			snrTrain,
			flagN,
		)

	# Initial trainEval call before any pruning; its return value is not used.
	runTrainEval(10)

	numZerosHis = []
	numElementsHis = []
	sparsityHis = []
	numZerosFinalHis = []
	numElementsFinalHis = []
	sparsityFinalHis = []
	berValidHis = []

	# Prune / fine-tune loop
	for k in range(numIterPrune):
		print("")
		print(f"pruneIter: {k}")

		# Pruning
		pruning(model, pruneMethod, pruneRatio)

		# Training (fine-tuning); only the third return value is kept
		_, _, berValid = runTrainEval(numFineTune)

		# Sparsity statistics measured on the weights using the pruning masks
		(
			numZerosList,
			numElementsList,
			sparsityList,
			numZerosFinalList,
			numElementsFinalList,
			sparsityFinalList,
		) = model.measureNnSeqSparsity(weight=True, bias=False, useMask=True)
		numZerosHis.append(numZerosList)
		numElementsHis.append(numElementsList)
		sparsityHis.append(sparsityList)
		numZerosFinalHis.append(numZerosFinalList)
		numElementsFinalHis.append(numElementsFinalList)
		sparsityFinalHis.append(sparsityFinalList)
		berValidHis.append(berValid[0])

		# Checkpoint the whole (pruned) model object for this round
		torch.save(model, '%s/nEq_%s_prune%d.pt' % (saveDir, mod, k))

	print("")
	print(f"numZerosHis:\n{numZerosHis}")
	print(f"numElementsHis:\n {numElementsHis}")
	print(f"sparsityHis:\n {sparsityHis}")
	print(f"numZerosFinalHis:\n {numZerosFinalHis}")
	print(f"numElementsFinalHis:\n {numElementsFinalHis}")
	print(f"sparsityFinalHis:\n {sparsityFinalHis}")
	print(f"berValidHis:\n {berValidHis}")
	print("")





if __name__ == "__main__":
	# Script entry point: load the config and the pre-trained model, then run
	# iterative pruning with fine-tuning and report the elapsed time.

	#*************************HEADER***********************#
	startTime = time.time()
	args = parsing_def()
	sys.path.insert(0, './config')
	config_module = __import__('config_{}'.format(args.config))
	# NOTE: cfg has to stay a module-level name; iterPrune() reads
	# cfg['prune']['dataSizeValid'] from it as a global.
	cfg = config_module.config

	# Map the modulation name to modNum; unknown names abort the run.
	modLevels = {'nrz': 2, 'pam4': 4, 'pam8': 8}
	modNum = modLevels.get(cfg['prune']['mod'])
	if modNum is None:
		sys.exit('invalid modulation')

	# Base delay is a quarter of the input size; the offset is the negated
	# index of the largest chSBR entry.
	delay = int((cfg['prune']['inSize']) / 4)
	chSBRList = list(cfg['prune']['chSBR'])
	delayOffset = -chSBRList.index(max(cfg['prune']['chSBR']))
	simName = args.name
	#******************************************************#

	# Params
	pruneMethod = prune.L1Unstructured

	# Load pre-trained model.
	# map_location lets a checkpoint saved on another device (e.g. CUDA) be
	# loaded on the device selected by the project's device module.
	# NOTE(review): torch.load unpickles the file, so only load trusted
	# checkpoints. Recent PyTorch releases default to weights_only=True, which
	# rejects whole-model checkpoints like the ones written by iterPrune();
	# confirm against the installed torch version.
	nEqLoad = torch.load(cfg['prune']['modelFile'], map_location=device.device)
	nEqLoad = nEqLoad.to(device.device)

	# Pruning & fine tune
	# The learning rate comes from the 'train' section of the config, unlike
	# the other settings, which come from 'prune'. Weight decay is disabled
	# (the original inline note mentions 1e-5 as an alternative).
	opt = torch.optim.Adam(nEqLoad.parameters(), lr=cfg['train']['lr'], weight_decay=0)

	iterPrune(
		nEqLoad,						#@@ Pre-trained model to be pruned
		pruneMethod,					#@@ Pruning method
		cfg['prune']['pruneRatio'],		#@@ Pruning ratio

		cfg['prune']['numIterPrune'],	#@@ Number of pruning
		cfg['prune']['numFineTune'],	#@@ Number of fine tuning for each pruning

		cfg['prune']['snrValid'],		#@@ SNR for validation data set
		cfg['prune']['noiseFlag'],		#@@ Noise switch
		cfg['prune']['chSBR'],			#@@ Channel SBR

		cfg['prune']['mod'],
		cfg['prune']['inSize'],
		cfg['prune']['outSize'],
		cfg['prune']['batchSize'],
		delay + delayOffset,
		cfg['prune']['lossFn'],
		opt,
		int(cfg['prune']['dataSizeTrain']),
		cfg['prune']['snrTrain'],
		simName,
		)

	timeSim = (time.time() - startTime) / 60.  # Unit: minutes
	print(f"Total simulation time: {timeSim} mins")
